Balance a binary search tree [DFS+Stack,DFS+Recursion]

Time: O(N); Space: O(N) (the sorted list of values; the traversal stack alone is O(H)); medium

Given a binary search tree, return a balanced binary search tree with the same node values.

A binary search tree is balanced if and only if the depths of the two subtrees of every node never differ by more than 1.

If there is more than one answer, return any of them.

Example 1:

Input: root = {TreeNode} [1,null,2,null,3,null,4,null,null]

Output: [2,1,3,null,null,null,4] or [3,1,4,null,2,null,null]

Constraints:

  • The number of nodes in the tree is between 1 and 10^4.

  • The tree nodes will have distinct values between 1 and 10^5.

Hints:

  1. Convert the tree to a sorted array using an in-order traversal.

  2. Construct a new balanced tree from the sorted array recursively.

[1]:
class TreeNode(object):
    """Binary tree node: a value plus links to a left and a right child."""

    def __init__(self, x):
        """Create a leaf node storing ``x``; both child links start as None."""
        self.val, self.left, self.right = x, None, None

1. DFS solution with stack

[2]:
class Solution1(object):
    """
    Balance a BST without recursion: flatten it in order, then rebuild.

    Time: O(N)
    Space: O(N) for the list of values; the traversal stack is O(H) and the
           construction stack is O(log N)
    """
    def balanceBST(self, root):
        """
        Return a new balanced BST holding the same values as ``root``.

        :type root: TreeNode
        :rtype: TreeNode
        """
        def collect_sorted(tree):
            # In-order walk with an explicit stack: descend along the left
            # spine, emit the node on the way back up, then switch to its
            # right subtree.
            values, spine = [], []
            current = tree
            while current is not None or spine:
                while current is not None:
                    spine.append(current)
                    current = current.left
                current = spine.pop()
                values.append(current.val)
                current = current.right
            return values

        def build_balanced(values):
            # Each pending entry is a half-open range [lo, hi) together with
            # the parent node and the attribute name ("left"/"right") the
            # range's subtree root must be attached to. The initial entry has
            # no parent: its node becomes the root of the whole tree.
            top = None
            pending = [(0, len(values), None, None)]
            while pending:
                lo, hi, parent, side = pending.pop()
                if lo >= hi:
                    continue  # empty range -> that child stays None
                middle = (lo + hi) // 2
                subtree_root = TreeNode(values[middle])
                if parent is None:
                    top = subtree_root
                else:
                    setattr(parent, side, subtree_root)
                # Push right first so the left half is popped (built) first.
                pending.append((middle + 1, hi, subtree_root, "right"))
                pending.append((lo, middle, subtree_root, "left"))
            return top

        return build_balanced(collect_sorted(root))
[14]:
s = Solution1()

# Build the degenerate, right-leaning BST 1 -> 2 -> 3 -> 4.
root = TreeNode(1)
tail = root
for value in (2, 3, 4):
    tail.right = TreeNode(value)
    tail = tail.right

res = s.balanceBST(root)

# Expected shape, i.e. [3,2,4,1] in level order:
#       3
#      / \
#     2   4
#    /
#   1
assert [res.val, res.left.val, res.right.val, res.left.left.val] == [3, 2, 4, 1]

# The problem statement's other accepted answer, [2,1,3,null,null,null,4],
# is equally balanced but is not what this implementation produces:
# assert res.val == 2
# assert res.left.val == 1
# assert res.right.val == 3
# assert res.right.right.val == 4

2. DFS solution with recursion

[15]:
class Solution2(object):
    """
    Balance a BST: flatten it in order, then rebuild it recursively.

    Time: O(N)
    Space: O(N) for the list of values
    """

    def balanceBST(self, root):
        """
        Return a new balanced BST holding the same values as ``root``.

        The input is by nature allowed to be badly skewed (a plain chain of
        nodes in the worst case), so the flattening step walks the tree with
        an explicit stack: recursing once per level there can exceed Python's
        default recursion limit and raise RecursionError on deep trees.
        The rebuild halves its index range on every call, so its recursion
        depth is only O(log N) and it stays recursive.

        :type root: TreeNode
        :rtype: TreeNode
        """
        def inorderTraversalHelper(node, arr):
            # Append the values of the subtree rooted at ``node`` to ``arr``
            # in ascending (in-order) order, without recursion.
            stack = []
            while node is not None or stack:
                # Go as far left as possible, remembering the path.
                while node is not None:
                    stack.append(node)
                    node = node.left
                node = stack.pop()
                arr.append(node.val)
                # Everything in the right subtree comes after this node.
                node = node.right

        def sortedArrayToBstHelper(arr, i, j):
            # Build a balanced subtree from the half-open range arr[i:j].
            if i >= j:
                return None
            # For an even-sized range this picks the upper of the two middle
            # elements, so any extra node ends up in the left subtree.
            mid = i + (j - i) // 2
            node = TreeNode(arr[mid])
            node.left = sortedArrayToBstHelper(arr, i, mid)
            node.right = sortedArrayToBstHelper(arr, mid + 1, j)
            return node

        arr = []
        inorderTraversalHelper(root, arr)

        return sortedArrayToBstHelper(arr, 0, len(arr))
[16]:
s = Solution2()

# Build the right-leaning chain 1 -> 2 -> 3 -> 4 from the bottom up:
# each new node takes the chain built so far as its right child.
root = None
for value in (4, 3, 2, 1):
    parent = TreeNode(value)
    parent.right = root
    root = parent

res = s.balanceBST(root)

# Expected balanced result: 3 at the root, 2 and 4 as its children,
# and 1 as the left child of 2.
assert res.val == 3
left_child, right_child = res.left, res.right
assert left_child.val == 2
assert right_child.val == 4
assert left_child.left.val == 1

# The problem statement's other accepted answer, [2,1,3,null,null,null,4],
# is equally balanced but is not what this implementation produces:
# assert res.val == 2
# assert res.left.val == 1
# assert res.right.val == 3
# assert res.right.right.val == 4